# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel"""

import logging
from typing import Any, List, Tuple

import sympy

import torch
from torch._inductor.virtualized import V
from torch.utils._pytree import tree_map
from .. import config
from ..ir import (
    ComputedBuffer,
    FixedLayout,
    FlexibleLayout,
    InputBuffer,
    StorageBox,
    Subgraph,
    TensorBox,
)
from ..lowering import empty, empty_strided, lowerings, register_lowering
from ..select_algorithm import autotune_select_algorithm, TritonTemplate

log = logging.getLogger(__name__)
aten = torch.ops.aten


def flex_attention_grid(batch_size, num_heads, num_queries, d_model, meta):
    """Compute the launch grid for the flex_attention forward kernel.

    The returned grid is
    ``(ceil_div(num_queries, meta["BLOCK_M"]), batch_size * num_heads, 1)``:
    axis 0 enumerates tiles of ``BLOCK_M`` queries and axis 1 enumerates
    (batch, head) pairs. Each program iterates over blocks of keys and values
    to produce the attention output for its query tile.

    ``d_model`` is accepted but not used by the grid computation.
    """
    import triton

    num_query_tiles = triton.cdiv(num_queries, meta["BLOCK_M"])
    num_batch_heads = batch_size * num_heads
    return (num_query_tiles, num_batch_heads, 1)


def create_placeholder(
    name: str, dtype: torch.dtype, device: torch.device
) -> TensorBox:
    """Build a zero-dimensional placeholder input used when lowering a subgraph.

    The placeholder is an ``InputBuffer`` called ``name`` with a ``FixedLayout``
    of empty size and stride on ``device`` with ``dtype``, wrapped in a
    ``TensorBox``.
    """
    scalar_layout = FixedLayout(device, dtype, [], [])
    return TensorBox.create(InputBuffer(name, scalar_layout))


def build_subgraph_buffer(
    args: List[TensorBox],
    subgraph: Subgraph,
):
    """Lower ``subgraph`` into the buffer(s) that get inlined into a Triton template.

    The FX graph in ``subgraph.graph_module`` is walked node by node and each
    ``call_function`` node is lowered with the default Inductor lowerings.

    Args:
        args: The TensorBoxes bound to the graph's placeholders, in placeholder
            order. Contains both fixed and lifted inputs.
        subgraph: The Subgraph IR for which to produce the output node(s).

    Returns:
        A structure mirroring the graph's output (a single element or a
        list of elements) in which every output node is replaced by a
        ``ComputedBuffer`` with a ``FlexibleLayout``; ``None`` outputs are
        returned as ``None``.

    Raises:
        ValueError: If the graph contains no ``output`` node.
    """
    cnt = 0
    env = {}
    for node in subgraph.graph_module.graph.nodes:
        if node.op == "placeholder":
            # Placeholders are bound positionally: the i-th placeholder in the
            # graph receives the i-th entry of ``args``. Callers are expected to
            # pass the newly created scalar placeholders first, followed by the
            # already-existing TensorBoxes for lifted inputs.
            env[node] = args[cnt]
            cnt += 1
        elif node.op == "call_function":
            # For call_function we use the default lowerings and pass in the
            # already created TensorBoxes as args.
            # NB: deliberately not named ``args``/``kwargs``. Rebinding ``args``
            # here would clobber the placeholder inputs and break any
            # placeholder visited after the first call_function node.
            node_args, node_kwargs = tree_map(
                lambda x: env[x] if x in env else x, (node.args, node.kwargs)
            )
            env[node] = lowerings[node.target](*node_args, **node_kwargs)
        elif node.op == "output":

            def convert_output_node_to_buffer(output):
                """Wrap the lowered value of one output node in a ComputedBuffer."""
                if output is None:
                    return None
                output_buffer = env[output]
                assert isinstance(output_buffer, TensorBox), (
                    "The output node  for flex attention's subgraph must be a TensorBox, but got: ",
                    type(output_buffer),
                )
                assert isinstance(output_buffer.data, StorageBox), (
                    "The output node for the flex attention subgraph must be a StorageBox, but got: ",
                    type(output_buffer.data),
                )
                subgraph_buffer = ComputedBuffer(
                    name=None,
                    layout=FlexibleLayout(
                        device=output_buffer.data.get_device(),
                        dtype=output_buffer.data.get_dtype(),
                        size=output_buffer.data.get_size(),
                    ),
                    data=output_buffer.data.data,  # type: ignore[arg-type]
                )
                return subgraph_buffer

            # node.args[0] is either a single element or a list of elements
            # representing all outputs of the function.
            return tree_map(convert_output_node_to_buffer, node.args[0])

    raise ValueError("FlexAttention was passed a subgraph with no output node!")


# Triton template for the FlexAttention forward pass.
#
# The template source below defines two Triton functions:
#   * the main kernel (emitted through `def_kernel`): program_id(0) selects a
#     BLOCK_M tile of queries and program_id(1) selects the (batch, head) pair.
#     It loads the Q tile once, calls `forward_inner` for each set of
#     block-sparse K/V blocks, then stores `acc / l_i` as the output and, when
#     OUTPUT_LOGSUMEXP is set, `m_i + log2(l_i)` into LSE.
#   * `forward_inner`: the online-softmax loop over K/V tiles. It inlines
#     subgraph 0 as the score modification and, when IS_FULL_BLOCKS is false,
#     subgraph 1 as the mask. Scores are scaled by RCP_LN2 and exponentiated
#     with exp2, so the stored logsumexp is expressed in base 2.
#
# NOTE(review): the in-template comments describe KV_NUM_BLKS / FULL_KV_IDX as
# the "fully unmasked" blocks, yet the pass that consumes them calls
# `forward_inner` with IS_FULL_BLOCKS=False (mask applied), while the
# PARTIAL_KV_* pass uses IS_FULL_BLOCKS=True (mask skipped). The FULL/PARTIAL
# naming looks inverted relative to the usage - confirm before relying on it.
# NOTE(review): `offs_n` inside the HAS_FULL_BLOCKS branch of the main kernel
# and `lo` inside `forward_inner` are assigned but never read in this template.
# NOTE(review): the Q load has no boundary check and the LSE store has no mask,
# so the kernel presumably requires Q_LEN to be a multiple of BLOCK_M - confirm.
flex_attention_template = TritonTemplate(
    name="flex_attention",
    grid=flex_attention_grid,
    source=r"""
{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "FULL_KV_IDX", "PARTIAL_KV_NUM_BLKS", "PARTIAL_KV_IDX")}}
    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of fully unmasked K/V blocks for each query.
    # FULL_KV_IDX: The indices of fully unmasked K/V blocks for each query.
    # PARTIAL_KV_NUM_BLKS: The number of partially unmasked K/V blocks for each query.
    # PARTIAL_KV_IDX: The indices of partially unmasked K/V blocks for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check

    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}}
    stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
    stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}

    Z = {{size("Q", 0)}}
    H = {{size("Q", 1)}}
    Q_LEN = {{size("Q", 2)}}
    KV_LEN = {{size("K", 2)}}

    MATMUL_PRECISION = Q.dtype.element_ty

    q_start = tl.program_id(0)
    off_z = tl.program_id(1) // H
    off_h = tl.program_id(1) % H

    q_offset = off_z * stride_qz + off_h * stride_qh
    k_offset = off_z * stride_kz + off_h * stride_kh
    v_offset = off_z * stride_vz + off_h * stride_vh

    Q = Q + q_offset
    K = K + k_offset
    V = V + v_offset

    SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
    SPARSE_H = {{size("KV_NUM_BLKS", 1)}}

    sparse_idx_z = off_z % SPARSE_Z
    sparse_idx_h = off_h % SPARSE_H

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    SPARSE_Q_BLOCK_CNT: tl.constexpr = Q_LEN // SPARSE_Q_BLOCK_SIZE
    SPARSE_KV_BLOCK_CNT: tl.constexpr = KV_LEN // SPARSE_KV_BLOCK_SIZE

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)

    # FULL_KV_IDX and KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_H + sparse_idx_h
    sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + q_start // SPARSE_Q_MULTIPLE
    sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (q_start // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT  # noqa: B950

    Q_block_ptr = tl.make_block_ptr(
        base=Q,
        shape=(Q_LEN, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(q_start * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # load q: it stays in SRAM throughout the inner loop.
    q = tl.load(Q_block_ptr)

    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We don't know anything "special" about these blocks, so we need to apply
    # both score_mod and mask_fn to it
    kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
    kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
    sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

    K_block_ptr = tl.make_block_ptr(
        base=K,
        shape=(BLOCK_DMODEL, KV_LEN),
        strides=(stride_kk, stride_kn),
        offsets=(0, kv_start),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(KV_LEN, BLOCK_DMODEL),
        strides=(stride_vn, stride_vk),
        offsets=(kv_start, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )

    acc, l_i, m_i = forward_inner(
        q, K_block_ptr, V_block_ptr,
        acc, l_i, m_i,
        off_z, off_h, offs_m,
        kv_start,
        kv_indices, sparse_kv_num_blocks,
        SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE,
        BLOCK_M, BLOCK_N, BLOCK_DMODEL, PRESCALE_QK, SM_SCALE, ROWS_GUARANTEED_SAFE, MATMUL_PRECISION,
        {{gen_argdefs()}},
        IS_FULL_BLOCKS=False
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_fn to them - only score_mod
    if HAS_FULL_BLOCKS:
        # PARTIAL_KV_IDX and PARTIAL_KV_NUM_BLKS are always contiguous.
        kv_indices = PARTIAL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        sparse_kv_num_blocks = tl.load(PARTIAL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

        K_block_ptr = tl.make_block_ptr(
            base=K,
            shape=(BLOCK_DMODEL, KV_LEN),
            strides=(stride_kk, stride_kn),
            offsets=(0, kv_start),
            block_shape=(BLOCK_DMODEL, BLOCK_N),
            order=(0, 1)
        )
        V_block_ptr = tl.make_block_ptr(
            base=V,
            shape=(KV_LEN, BLOCK_DMODEL),
            strides=(stride_vn, stride_vk),
            offsets=(kv_start, 0),
            block_shape=(BLOCK_N, BLOCK_DMODEL),
            order=(1, 0)
        )
        # initialize offsets
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            q, K_block_ptr, V_block_ptr,
            acc, l_i, m_i,
            off_z, off_h, offs_m,
            kv_start,
            kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE,
            BLOCK_M, BLOCK_N, BLOCK_DMODEL, PRESCALE_QK, SM_SCALE, ROWS_GUARANTEED_SAFE, MATMUL_PRECISION,
            {{gen_argdefs()}},
            IS_FULL_BLOCKS=True
        )


    # Store output and logsumexp
    acc = acc / l_i[:, None]
    idx_z = tl.program_id(1) // H
    idx_h = tl.program_id(1) % H

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_m = offs_m[:, None]
    idx_d = tl.arange(0, BLOCK_DMODEL)[None, :]

    mask = idx_m < Q_LEN
    # TODO generalize and add proper mask support
    {{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask")}}

    # TODO dont want to write this if we dont require grad
    if OUTPUT_LOGSUMEXP:
        off_hz = tl.program_id(1)
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        tl.store(l_ptrs, lse)


@triton.jit
def forward_inner(
    q, K_block_ptr, V_block_ptr,
    acc, l_i, m_i,
    off_z, off_h, offs_m,
    kv_start,
    kv_indices, sparse_kv_num_blocks,
    SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE,
    BLOCK_M, BLOCK_N, BLOCK_DMODEL, PRESCALE_QK, SM_SCALE, ROWS_GUARANTEED_SAFE, MATMUL_PRECISION,
    {{gen_argdefs()}},
    IS_FULL_BLOCKS,
):

    # initialize offsets
    offs_n = kv_start + tl.arange(0, BLOCK_N)

    RCP_LN2 = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    # loop over k, v and update accumulator
    lo = 0
    hi = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE

    for start_n in range(0, hi):
        # -- load k --
        k = tl.load(K_block_ptr)
        # -- compute qk ---
        qk = tl.dot(q, k)
        if not PRESCALE_QK:
            qk *= SM_SCALE
        # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
        m = offs_m[:, None]
        n = offs_n[None, :]
        {{ modification(
            subgraph_number=0,
            output_name="post_mod_scores",
            score="qk",
            b="off_z",
            h="off_h",
            m="m",
            n="n",
            out="qk"
        ) | indent_except_first(2) }}

        if not IS_FULL_BLOCKS:
            {{ modification(
                subgraph_number=1,
                output_name="mask_fn_output",
                score="qk",
                b="off_z",
                h="off_h",
                m="m",
                n="n",
            ) | indent_except_first(3) }}
            # apply mask for partially unmasked blocks
            post_mod_scores = tl.where(mask_fn_output, post_mod_scores, float("-inf"))

        # TODO: In the case that score_mod is linear, this can be LICMed
        if not PRESCALE_QK:
            post_mod_scores *= RCP_LN2
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

        # -- compute scaling constant ---
        m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))

        alpha = tl.math.exp2(m_i - m_ij)
        p = tl.math.exp2(post_mod_scores - m_ij[:, None])
        if not ROWS_GUARANTEED_SAFE:
            masked_out_rows = (m_ij == float("-inf"))
            alpha = tl.where(masked_out_rows, 0, alpha)
            p = tl.where(masked_out_rows[:, None], 0, p)

        # NB: l_i update is pulled up here since it's a bit faster
        # NB: For headdim=256, it's faster to move it back down to after m_i =
        # m_ij
        l_i = l_i * alpha + tl.sum(p, 1)
        # # -- scale and update acc --
        acc = acc * alpha[:, None]
        v = tl.load(V_block_ptr)
        acc = tl.dot(p.to(MATMUL_PRECISION), v, acc)

        # -- update m_i
        m_i = m_ij

        # update pointers
        indices_idx = start_n // SPARSE_KV_MULTIPLE

        cur_block = tl.load(kv_indices + indices_idx, eviction_policy="evict_last")
        next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last", mask=indices_idx + 1 < sparse_kv_num_blocks)
        needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
        jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N

        offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N

        V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, offset))

        offs_n = offs_n + offset

    return acc, l_i, m_i

 """,
)


def _use_flex_decoding(query):
    """Return True when the flex decoding kernel should be used for ``query``.

    The decoding kernel is selected when the second-to-last dimension of the
    query size evaluates to less than 128.
    """
    num_queries = query.get_size()[-2]
    is_short_query = sympy.Lt(num_queries, 128)
    return V.graph.sizevars.evaluate_expr(is_short_query)


# Default forward-kernel configurations, looked up by `_get_default_config_fwd`
# when the device capability is >= (9, 0) (labelled "H100" there).
# Key:   (query dtype, head_dim)
# Value: (BLOCK_M, BLOCK_N, num_warps, num_stages)
_h100_default_config = {
    (torch.float32, 64): (128, 32, 4, 3),
    (torch.float32, 128): (32, 64, 4, 3),
    (torch.float32, 256): (32, 32, 4, 3),
    (torch.bfloat16, 64): (128, 128, 4, 3),
    (torch.bfloat16, 128): (128, 64, 8, 3),
    (torch.bfloat16, 256): (64, 32, 4, 3),
    (torch.float16, 64): (128, 128, 4, 3),
    (torch.float16, 128): (128, 128, 8, 3),
    (torch.float16, 256): (64, 32, 4, 3),
}

# Default forward-kernel configurations, looked up by `_get_default_config_fwd`
# when the device capability is >= (8, 0) but below (9, 0) (labelled "A100" there).
# Key:   (query dtype, head_dim)
# Value: (BLOCK_M, BLOCK_N, num_warps, num_stages)
_a100_default_config = {
    (torch.float32, 64): (128, 32, 4, 3),
    (torch.float32, 128): (128, 32, 4, 3),
    (torch.float32, 256): (64, 16, 4, 3),
    (torch.bfloat16, 64): (128, 64, 4, 3),
    (torch.bfloat16, 128): (128, 64, 8, 3),
    (torch.bfloat16, 256): (32, 64, 4, 3),
    (torch.float16, 64): (128, 64, 4, 3),
    (torch.float16, 128): (128, 64, 8, 3),
    (torch.float16, 256): (32, 64, 4, 3),
}


def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
    dtype = query.get_dtype()
    head_dim = query.get_size()[-1]
    default_config = None

    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
        if dtype == torch.float32:
            default_config = (64, 64, 4, 3)
        else:
            default_config = (128, 64, 4, 3)
        default_config = _h100_default_config.get((dtype, head_dim), default_config)
    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
        if dtype == torch.float32:
            default_config = (64, 64, 4, 3)
        else:
            default_config = (128, 64, 4, 3)
        default_config = _a100_default_config.get((dtype, head_dim), default_config)
    else:  # modest hardware or extremely large head_dim
        if dtype == torch.float32:
            default_config = (32, 16, 4, 3)
        else:
            default_config = (64, 32, 4, 3)

    return default_config


def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
    head_dim = query.get_size()[-1]
    dtype = query.get_dtype()

    if dtype == torch.float32:
        return (16, 16, 4, 1)
    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
        if head_dim == 64:
            return (64, 64, 4, 3)
        elif head_dim == 128:
            return (64, 128, 8, 3)
        else:
            return (64, 64, 4, 2)
    elif torch.cuda.get_device_capability() >= (8, 0):  # A100
        if head_dim == 64:
            return (32, 128, 4, 3)
        elif head_dim == 128:
            return (64, 128, 8, 3)
        else:
            return (64, 64, 4, 2)
    else:  # modest hardware or extremely large head_dim
        return (16, 16, 4, 1)


def create_num_blocks_fake_generator(sparse_indices):
    """Return a generator of representative ``num_blocks`` tensors for benchmarking.

    Autotuning needs real tensors with realistic contents: an all-zero
    ``kv_num_blocks`` input would mean zero blocks are computed per row and
    give bogus autotuning results. The returned function therefore fills a
    tensor shaped like its argument with ``min(16, max_block)`` where
    ``max_block`` is the last dimension of ``sparse_indices``. The cap of 16
    is a heuristic: too few blocks and prefetching does not help, too many and
    autotuning just takes longer.
    """

    def create_num_blocks_fake(x) -> torch.Tensor:
        max_block = sparse_indices.shape[-1]
        fill_value = int(min(16, max_block))
        return torch.full(
            x.get_size(),
            fill_value,
            dtype=x.get_dtype(),
            device=x.get_device(),
        )

    return create_num_blocks_fake


def create_indices_fake(x) -> torch.Tensor:
    """Create a contiguous index tensor shaped like ``x`` for benchmarking.

    Every slice along the last dimension holds ``0, 1, ..., size[-1] - 1``,
    using the dtype and device reported by ``x``.
    """
    size = x.get_size()
    last_dim_range = torch.arange(
        0, int(size[-1]), dtype=x.get_dtype(), device=x.get_device()
    )
    return last_dim_range.expand(size).contiguous()


from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel


# TODO: We probably also need a layout constraint?
@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
def flex_attention(
    query,
    key,
    value,
    subgraph,
    block_mask,
    scale,
    score_mod_other_buffers,
    mask_fn_other_buffers,
):
    """Lower the ``flex_attention`` higher-order op to a Triton template kernel.

    Args:
        query, key, value: The attention inputs; they are realized before use.
        subgraph: The score-modification subgraph, lowered with the placeholders
            ``score, b, h, m, n`` followed by ``score_mod_other_buffers``.
        block_mask: An 11-tuple unpacked as ``(kv_num_blocks, kv_indices,
            full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices,
            full_q_num_blocks, full_q_indices, SPARSE_KV_BLOCK_SIZE,
            SPARSE_Q_BLOCK_SIZE, mask_graph)``. Entries of the buffer slots may
            be ``None``.
        scale: Passed to the kernel as ``SM_SCALE``.
        score_mod_other_buffers: Lifted inputs of ``subgraph``.
        mask_fn_other_buffers: Lifted inputs of ``mask_graph``.

    Returns:
        The result of ``create_flex_decoding_kernel`` when
        ``_use_flex_decoding(query)`` is true; otherwise a tuple of the
        autotuned template output and the float32 logsumexp buffer.
    """
    (
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_KV_BLOCK_SIZE,
        SPARSE_Q_BLOCK_SIZE,
        mask_graph,
    ) = block_mask
    for buf in [
        query,
        key,
        value,
        kv_num_blocks,
        kv_indices,
        q_num_blocks,
        q_indices,
        full_kv_num_blocks,
        full_kv_indices,
        full_q_num_blocks,
        full_q_indices,
    ]:
        if buf is not None:
            buf.realize()
    placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("score", query.get_dtype()),
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    subgraph_buffer = build_subgraph_buffer(
        placeholder_inps + list(score_mod_other_buffers), subgraph
    )
    if _use_flex_decoding(query):
        return create_flex_decoding_kernel(
            subgraph_buffer,
            query,
            key,
            value,
            subgraph,
            scale,
            *score_mod_other_buffers,
        )
    mask_graph_placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_fn_other_buffers), mask_graph
    )
    # The output has the same device, dtype, size and stride as the query.
    layout = FixedLayout(
        query.get_device(),
        query.get_dtype(),
        query.get_size(),
        query.get_stride(),
    )
    # see NOTE:[TritonTemplates with multiple outputs]
    logsumexp_shape = query.get_size()[:-1]  # [B, H, M]
    logsumexp = empty_strided(
        logsumexp_shape,
        None,
        dtype=torch.float32,  # The logsumexp is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    # Inside of Triton kernel, only apply partial masking if partial blocks are computed.
    # full_kv_num_blocks is None if partial blocks are not computed
    has_full_blocks = full_kv_num_blocks is not None
    if full_kv_num_blocks is None:
        # The template still needs two buffers in these slots; pass empty ones.
        full_kv_num_blocks, full_kv_indices = (
            empty(0, device=query.get_device()) for _ in range(2)
        )
    choices: List[Any] = []
    configs: List[Tuple[int, int, int, int]] = []
    configs.append(_get_default_config_fwd(query))
    if config.max_autotune:
        configs += [
            (128, 64, 4, 3),
            (128, 128, 4, 3),
            (128, 128, 8, 2),
            (64, 128, 4, 3),
            (64, 64, 4, 3),
        ]

    # Note, we don't need to pass in the captured buffers explicitly
    # because they're implicitly added by the score_mod function
    # We do need to explicitly pass it in for autotuning though.

    for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
        # The kernel statically asserts that the sparse block sizes are
        # multiples of the tile sizes, so skip configs that cannot satisfy it.
        if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0:
            continue
        # Work around https://github.com/pytorch/pytorch/issues/129625
        if num_stages == 2:
            continue

        flex_attention_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                logsumexp,
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
            ],
            layout=layout,
            subgraphs=[
                subgraph_buffer,
                mask_graph_buffer,
            ],
            mutated_inputs=[
                logsumexp,
            ],
            num_stages=num_stages,
            num_warps=num_warps,
            call_sizes=query.get_size(),
            OUTPUT_LOGSUMEXP=True,
            SM_SCALE=scale,
            BLOCK_DMODEL=query.get_size()[-1],
            # Performance tuning
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            # Blocksparse options
            SPARSE_Q_BLOCK_SIZE=SPARSE_Q_BLOCK_SIZE,
            SPARSE_KV_BLOCK_SIZE=SPARSE_KV_BLOCK_SIZE,
            # For now, we always assume the "sound" option
            ROWS_GUARANTEED_SAFE=False,
            PRESCALE_QK=False,
            HAS_FULL_BLOCKS=has_full_blocks,
        )
    inputs_for_autotuning = (
        [
            query,
            key,
            value,
            logsumexp,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
        + list(score_mod_other_buffers)
        + list(mask_fn_other_buffers)
    )
    # Keys are positions in `inputs_for_autotuning`. Every block-count buffer is
    # bounded by the last dimension of the indices buffer it is paired with, so
    # the benchmarked kernel only walks block indices that exist; every indices
    # buffer gets in-range values. Without generators for slots 6 and 7 the
    # benchmark would read arbitrary block counts / indices when
    # HAS_FULL_BLOCKS is set.
    input_gen_fns = {
        4: create_num_blocks_fake_generator(kv_indices),
        5: create_indices_fake,
        6: create_num_blocks_fake_generator(full_kv_indices),
        7: create_indices_fake,
    }
    return (
        autotune_select_algorithm(
            "flex_attention",
            choices,
            inputs_for_autotuning,
            layout,
            input_gen_fns=input_gen_fns,
        ),
        logsumexp,
    )


# ---------------------------- Backward HOP Implementation ----------------------------


def flex_attention_backward_grid(
    batch_size, num_heads, num_queries, d_model, num_key_value, meta
):
    """Compute the launch grid for the flex_attention backward kernel.

    The returned grid is
    ``(ceil_div(num_queries, BLOCK_M2) + ceil_div(num_key_value, BLOCK_N1), 1,
    batch_size * num_heads)``: axis 0 covers one program per ``BLOCK_M2`` query
    tile plus one program per ``BLOCK_N1`` key/value tile, and axis 2
    enumerates (batch, head) pairs.

    ``d_model`` is accepted but not used by the grid computation.
    """
    import triton

    num_query_tiles = triton.cdiv(num_queries, meta["BLOCK_M2"])
    num_kv_tiles = triton.cdiv(num_key_value, meta["BLOCK_N1"])
    num_batch_heads = batch_size * num_heads
    return (num_query_tiles + num_kv_tiles, 1, num_batch_heads)


flex_attention_backward_template = TritonTemplate(
    name="flex_attention_backward",
    grid=flex_attention_backward_grid,
    source=r"""
{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX", "PARTIAL_KV_NUM_BLKS", "PARTIAL_KV_IDX", "PARTIAL_Q_NUM_BLKS", "PARTIAL_Q_IDX")}}
    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT* DO, axis=1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of fully unmasked K/V blocks for each query.
    # FULL_KV_IDX: The indices of fully unmasked K/V blocks for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks for each key/value.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks for each key/value.
    # PARTIAL_KV_NUM_BLKS: The number of partially unmasked K/V blocks for each query.
    # PARTIAL_KV_IDX: The indices of partially unmasked K/V blocks for each query.
    # PARTIAL_Q_NUM_BLKS: The number of partially unmasked Q blocks for each key/value.
    # PARTIAL_Q_IDX: The indices of partially unmasked Q blocks for each key/value.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}}
    stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}}
    stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}}
    stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}}

    stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}}
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}

    Z = {{size("Q", 0)}}
    H = {{size("Q", 1)}}
    Q_LEN = {{size("Q", 2)}}
    KV_LEN = {{size("K", 2)}}

    MATMUL_PRECISION = Q.dtype.element_ty

    pid = tl.program_id(0)
    NUM_KV_BLOCKS = KV_LEN // BLOCK_N1

    off_hz = tl.program_id(2)
    off_z = off_hz // H # batch idx
    off_h = off_hz % H # head idx

    SM_Z = {{size("KV_NUM_BLKS", 0)}}
    SM_H = {{size("KV_NUM_BLKS", 1)}}

    sparse_idx_z = off_z % SM_Z
    sparse_idx_h = off_h % SM_H

    SPARSE_Q_BLOCK_CNT = Q_LEN // SPARSE_Q_BLOCK_SIZE
    SPARSE_KV_BLOCK_CNT = KV_LEN // SPARSE_KV_BLOCK_SIZE

    sparse_hz_offset = sparse_idx_z * SM_H + sparse_idx_h

    off_chz = (off_hz * Q_LEN).to(tl.int64)
    q_adj = (stride_qh * (off_hz % H) + stride_qz * (off_hz // H)).to(tl.int64)
    k_adj = (stride_kh * (off_hz % H) + stride_kz * (off_hz // H)).to(tl.int64)
    v_adj = (stride_vh * (off_hz % H) + stride_vz * (off_hz // H)).to(tl.int64)
    do_adj = (stride_doh * (off_hz % H) + stride_doz * (off_hz // H)).to(tl.int64)

    dq_adj = (stride_dqh * (off_hz % H) + stride_dqz * (off_hz // H)).to(tl.int64)
    dv_adj = (stride_dvh * (off_hz % H) + stride_dvz * (off_hz // H)).to(tl.int64)

    # offset pointers for batch/head
    Q += q_adj
    K += k_adj
    V += v_adj
    DO += do_adj
    # TODO: This does not work if DQ is not the same layout as Q (for example,
    # if Q is broadcasted)
    DQ += dq_adj
    DV += dv_adj
    LSE += off_chz
    DELTA += off_chz

    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, BLOCK_DMODEL)

    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)

        sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + off_pid // SPARSE_Q_MULTIPLE
        sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (off_pid // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT  # noqa: B950

        dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)

        start_m2 = off_pid * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)

        # load Q and do: they stay in SRAM throughout the inner loop.
        q = tl.load(Q + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
        do = tl.load(DO + offs_m2[:, None] * stride_dom + offs_k[None, :] * stride_dod)

        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

        Di = tl.load(DELTA + offs_m2)
        lse = tl.load(LSE + offs_m2)
        lse = lse[:, None]

        # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # FULL_KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)

        dq = bwd_dq_inner(
            dq, q, K, V, do, Di, lse,
            off_z, off_h, offs_m2, offs_k,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_start, kv_indices, sparse_kv_num_blocks, SPARSE_KV_MULTIPLE, SPARSE_KV_BLOCK_SIZE,
            BLOCK_M2, BLOCK_N2, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
            {{gen_argdefs()}},
            IS_FULL_BLOCKS=False
        )

        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # PARTIAL_KV_IDX and PARTIAL_KV_NUM_BLKS are always contiguous.
            kv_indices = PARTIAL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
            sparse_kv_num_blocks = tl.load(PARTIAL_KV_NUM_BLKS + sparse_kv_num_blks_offset)

            dq = bwd_dq_inner(
                dq, q, K, V, do, Di, lse,
                off_z, off_h, offs_m2, offs_k,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_start, kv_indices, sparse_kv_num_blocks, SPARSE_KV_MULTIPLE, SPARSE_KV_BLOCK_SIZE,
                BLOCK_M2, BLOCK_N2, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
                {{gen_argdefs()}},
                IS_FULL_BLOCKS=True
            )

        # Write back dQ.
        dq_ptrs = DQ + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        tl.store(dq_ptrs, dq)
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)

        sparse_q_num_blks_offset = sparse_hz_offset * SPARSE_KV_BLOCK_CNT + pid // SPARSE_KV_MULTIPLE
        sparse_q_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (pid // SPARSE_KV_MULTIPLE) * SPARSE_Q_BLOCK_CNT  # noqa: B950

        dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)

        # load K and V: they stay in SRAM throughout the inner loop.
        k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd)
        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
        v = tl.load(V + offs_n1[:, None] * stride_vn + offs_k[None, :] * stride_vd)

        # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
        q_indices = FULL_Q_IDX + sparse_q_idx_offset
        q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
        sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)

        start_m1 = q_start

        dk, dv = bwd_dkdv_inner(
            dk, dv, Q, k, v, DO, DELTA, LSE,
            off_z, off_h, offs_n1, offs_k, start_n1, start_m1,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_start, q_indices, sparse_q_num_blocks, SPARSE_Q_MULTIPLE, SPARSE_Q_BLOCK_SIZE,
            BLOCK_M1, BLOCK_N1, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
            {{gen_argdefs()}},
            IS_FULL_BLOCKS=False
        )


        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # PARTIAL_Q_IDX and PARTIAL_Q_NUM_BLKS are always contiguous.
            q_indices = PARTIAL_Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
            sparse_q_num_blocks = tl.load(PARTIAL_Q_NUM_BLKS + sparse_q_num_blks_offset)

            start_m1 = q_start

            dk, dv = bwd_dkdv_inner(
                dk, dv, Q, k, v, DO, DELTA, LSE,
                off_z, off_h, offs_n1, offs_k, start_n1, start_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_start, q_indices, sparse_q_num_blocks, SPARSE_Q_MULTIPLE, SPARSE_Q_BLOCK_SIZE,
                BLOCK_M1, BLOCK_N1, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
                {{gen_argdefs()}},
                IS_FULL_BLOCKS=True
            )

        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_k[None, :] * stride_dvd
        tl.store(dv_ptrs, dv)

        # Write back dK.
        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]

        dk *= SM_SCALE
        mask = index_n <= KV_LEN
        {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}}

@triton.jit
def bwd_dq_inner(
    dq, q, K, V, do, Di, lse,
    off_z, off_h, offs_m2, offs_k,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_start, kv_indices, sparse_kv_num_blocks, SPARSE_KV_MULTIPLE, SPARSE_KV_BLOCK_SIZE,
    BLOCK_M2, BLOCK_N2, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
    {{gen_argdefs()}}, IS_FULL_BLOCKS
):
    # Accumulate into `dq` the contribution of every K/V tile reachable through
    # `kv_indices`, for one fixed tile of queries, and return the updated `dq`.
    #
    # `q`, `do`, `Di` and `lse` are used as already-loaded values (no tl.load is
    # issued on them here); `K` and `V` are base pointers that get tiled below.
    # `offs_m2` / `offs_k` are the query-row and head-dim offsets of the tile.
    # `kv_start` is the first KV row to visit and `kv_indices` points at the list
    # of sparse KV block ids; `sparse_kv_num_blocks` is how many of them to walk.
    # The loop runs `sparse_kv_num_blocks * SPARSE_KV_MULTIPLE` tiles of width
    # BLOCK_N2, i.e. SPARSE_KV_MULTIPLE tiles per sparse block
    # (presumably SPARSE_KV_BLOCK_SIZE // BLOCK_N2 -- TODO confirm at the call site).
    # IS_FULL_BLOCKS=True skips the mask subgraph entirely.
    # The gen_argdefs template hook in the signature is expanded at codegen time
    # (presumably to the kernel's own argument list so the inlined subgraphs can
    # reach their captured buffers -- verify against the template class).
    # NOTE(review): the K/V loads are unmasked, so every visited tile is assumed
    # to be fully in bounds.
    start_n2 = kv_start
    offs_n2 = start_n2 + tl.arange(0, BLOCK_N2)
    # Transposed views: rows index the head dim, columns index KV positions.
    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_k[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    # `curr_n` tracks the KV row offset of the current tile; it moves in lockstep
    # with the pointers so the score/mask subgraphs see the right `n` indices.
    curr_n = start_n2
    hi = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
    for start_n in range(0, hi):
        offs_n2 = curr_n + tl.arange(0, BLOCK_N2)
        kT = tl.load(kT_ptrs)
        vT = tl.load(vT_ptrs)
        qk = tl.dot(q, kT)
        if not PRESCALE_QK:
            qk *= SM_SCALE
        # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
        # Keep the unmodified scores: the joint (backward) subgraph below takes
        # them as its `score` input.
        pre_mod_scores = qk
        m = offs_m2[:, None]
        n = offs_n2[None, :]
        # Subgraph 0 is the forward score_mod; it defines `post_mod_scores`.
        {{ modification(
            subgraph_number=0,
            output_name="post_mod_scores",
            score="qk",
            b="off_z",
            h="off_h",
            m="m",
            n="n",
            out="qk"
        ) | indent_except_first(2) }}

        if not IS_FULL_BLOCKS:
            # Subgraph 2 is the mask function; it defines `mask_fn_output`.
            {{ modification(
                subgraph_number=2,
                output_name="mask_fn_output",
                score="qk",
                b="off_z",
                h="off_h",
                m="m",
                n="n",
            ) | indent_except_first(3) }}

            # apply mask for partial masked block
            post_mod_scores = tl.where(mask_fn_output, post_mod_scores, float("-inf"))
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Rescale so exp2 can be used below (RCP_LN2 is presumably 1/ln(2), and
        # `lse` presumably already lives in the base-2 domain and broadcasts
        # against the score tile -- TODO confirm against the caller).
        if not PRESCALE_QK:
            post_mod_scores *= RCP_LN2
        # Recompute the attention probabilities from the saved log-sum-exp.
        p = tl.math.exp2(post_mod_scores - lse)
        # Compute dP and dS.
        dp = tl.dot(do, vT)
        ds = p * (dp - Di[:, None])
        # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
        # Subgraph 1 is the joint graph: given the pre-mod scores and the
        # incoming gradient `ds`, it defines `grad_scores`.
        {{ modification(
            subgraph_number=1,
            output_name = "grad_scores",
            score="pre_mod_scores",
            b="off_z",
            h="off_h",
            m="m",
            n="n",
            grad_score_mod="ds"
        ) | indent_except_first(2) }}
        ds = grad_scores

        if not IS_FULL_BLOCKS:
            # (grads) apply mask for partially unmasked block
            ds = tl.where(mask_fn_output, ds, 0.0)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        ds = ds.to(MATMUL_PRECISION)
        # Compute dQ.
        dq += tl.dot(ds, tl.trans(kT))

        # Increment pointers.
        # Inside a sparse block we step by one tile (BLOCK_N2). After the last
        # tile of a sparse block (`needs_jump`) we move to the first tile of the
        # next listed block: the distance between the two block starts minus the
        # (SPARSE_KV_MULTIPLE - 1) tiles already stepped over.
        indices_idx = start_n // SPARSE_KV_MULTIPLE
        cur_block = tl.load(kv_indices + indices_idx)
        # The load is masked off for the final block, so `next_block` has no
        # meaningful value there; that only happens on the last loop iteration,
        # after which the advanced pointers are never read again.
        next_block = tl.load(kv_indices + indices_idx + 1, mask=indices_idx + 1 < sparse_kv_num_blocks)
        needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
        jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N2
        # Branch-free select between the jump and the regular one-tile step.
        offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N2

        kT_ptrs += offset * stride_kn
        vT_ptrs += offset * stride_vn

        curr_n += offset

    return dq


@triton.jit
def bwd_dkdv_inner(
    dk, dv, Q, k, v, DO, DELTA, LSE,
    off_z, off_h, offs_n1, offs_k, start_n1, start_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_start, q_indices, sparse_q_num_blocks, SPARSE_Q_MULTIPLE, SPARSE_Q_BLOCK_SIZE,
    BLOCK_M1, BLOCK_N1, PRESCALE_QK, SM_SCALE, RCP_LN2, MATMUL_PRECISION,
    {{gen_argdefs()}}, IS_FULL_BLOCKS
):
    # Accumulate into `dk` and `dv` the contribution of every Q tile reachable
    # through `q_indices`, for one fixed tile of keys/values, and return both.
    #
    # `k` and `v` are used as already-loaded values; `Q`, `DO`, `DELTA` and `LSE`
    # are base pointers that are loaded tile by tile below. `start_n1` is the
    # first KV row of the fixed tile and `start_m1` the first query row to visit.
    # The loop runs `sparse_q_num_blocks * SPARSE_Q_MULTIPLE` tiles of BLOCK_M1
    # query rows, i.e. SPARSE_Q_MULTIPLE tiles per sparse block
    # (presumably SPARSE_Q_BLOCK_SIZE // BLOCK_M1 -- TODO confirm at the call site).
    # IS_FULL_BLOCKS=True skips the mask subgraph entirely.
    # Everything is computed in transposed form (KV rows x Q columns), hence the
    # `T` suffixes and the m/n index orientation used for the subgraphs.
    # NOTE(review): the `q_start` parameter is never read; the loop variable of
    # the same name below shadows it and is a tile counter, not a row offset.
    # NOTE(review): all loads are unmasked, so every visited tile is assumed to
    # be fully in bounds.
    offs_m1 = start_m1 + tl.arange(0, BLOCK_M1)
    # The incoming `offs_n1` argument is overwritten here from `start_n1`.
    offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
    # Q is addressed transposed (head dim x query rows); dO is not.
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_k[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # `curr_m` tracks the query row offset of the current tile; it moves in
    # lockstep with the pointers and is used to index LSE / DELTA.
    curr_m = start_m1
    hi = sparse_q_num_blocks * SPARSE_Q_MULTIPLE
    for q_start in range(0, hi):
        qT = tl.load(qT_ptrs)
        # Load LSE before computing qk to reduce pipeline stall.
        offs_m1 = curr_m + tl.arange(0, BLOCK_M1)
        lse = tl.load(LSE + offs_m1)
        qkT = tl.dot(k, qT)
        if not PRESCALE_QK:
            qkT *= SM_SCALE
        # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
        # Transposed layout: query index varies along columns, KV along rows.
        m = offs_m1[None, :]
        n = offs_n1[:, None]
        # Keep the unmodified scores: the joint (backward) subgraph below takes
        # them as its `score` input.
        pre_mod_scores = qkT
        # Subgraph 0 is the forward score_mod; it defines `post_mod_scores`.
        {{ modification(
            subgraph_number=0,
            output_name="post_mod_scores",
            score="qkT",
            b="off_z",
            h="off_h",
            m="m",
            n="n",
            out="qkT"
        ) | indent_except_first(2) }}
        if not IS_FULL_BLOCKS:
            # Subgraph 2 is the mask function; it defines `mask_fn_output`.
            {{ modification(
                subgraph_number=2,
                output_name="mask_fn_output",
                score="qkT",
                b="off_z",
                h="off_h",
                m="m",
                n="n",
            ) | indent_except_first(3) }}
            # apply mask for partially masked block (this branch is only taken
            # when the block is not flagged as full)
            post_mod_scores = tl.where(mask_fn_output, post_mod_scores, float("-inf"))
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Rescale so exp2 can be used below (RCP_LN2 is presumably 1/ln(2), and
        # LSE presumably already holds base-2 values -- TODO confirm against the
        # forward kernel).
        if not PRESCALE_QK:
            post_mod_scores *= RCP_LN2
        # Recompute the (transposed) attention probabilities from the saved LSE.
        pT = tl.math.exp2(post_mod_scores - lse[None, :])
        do = tl.load(do_ptrs)
        # Compute dV.
        # `ppT` is just an alias of `pT`; no extra transformation is applied.
        ppT = pT
        dv += tl.dot(ppT.to(MATMUL_PRECISION), do)
        Di = tl.load(DELTA + offs_m1)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do))
        dsT = pT * (dpT - Di[None, :])
        # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
        # Same values as computed above (offs_m1 / offs_n1 have not changed);
        # re-bound right before the joint subgraph that consumes them.
        m = offs_m1[None, :]
        n = offs_n1[:, None]
        # Subgraph 1 is the joint graph: given the pre-mod scores and the
        # incoming gradient `dsT`, it defines `grad_scores`.
        {{ modification(
            subgraph_number=1,
            output_name = "grad_scores",
            score="pre_mod_scores",
            b="off_z",
            h="off_h",
            m="m",
            n="n",
            grad_score_mod="dsT"
        ) | indent_except_first(2) }}
        dsT = grad_scores
        if not IS_FULL_BLOCKS:
            # (grads) apply mask for partially unmasked block
            dsT = tl.where(mask_fn_output, dsT, 0.0)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Compute dK (the caller is responsible for any final scaling).
        dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT))
        # Increment pointers.
        # Inside a sparse block we step by one tile (BLOCK_M1). After the last
        # tile of a sparse block (`needs_jump`) we move to the first tile of the
        # next listed block: the distance between the two block starts minus the
        # (SPARSE_Q_MULTIPLE - 1) tiles already stepped over.
        indices_idx = q_start // SPARSE_Q_MULTIPLE
        cur_block = tl.load(q_indices + indices_idx)
        # The load is masked off for the final block, so `next_block` has no
        # meaningful value there; that only happens on the last loop iteration,
        # after which the advanced pointers are never read again.
        next_block = tl.load(q_indices + indices_idx + 1, mask=indices_idx + 1 < sparse_q_num_blocks)
        needs_jump = (q_start + 1) % SPARSE_Q_MULTIPLE == 0
        jump_to_block = (next_block - cur_block ) * SPARSE_Q_BLOCK_SIZE - (SPARSE_Q_MULTIPLE - 1) * BLOCK_M1
        # Branch-free select between the jump and the regular one-tile step.
        offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_M1

        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom

        curr_m += offset

    return dk, dv

 """,
)


# TODO: We probably also need a layout constraint?
@register_lowering(
    torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None
)
def flex_attention_backward(*args, **kwargs):
    """Lower the ``flex_attention_backward`` higher-order op onto the Triton template.

    Args:
        *args: Exactly twelve positional entries, in this order: ``query``,
            ``key``, ``value``, ``out``, ``logsumexp``, ``grad_out``,
            ``fw_graph``, ``joint_graph``, ``block_mask``, ``scale``,
            ``score_mod_other_buffers`` and ``mask_fn_other_buffers``.
            ``block_mask`` is itself an 11-tuple holding the (partial and full)
            kv/q block counts and indices, the two sparse block sizes and the
            mask graph.
        **kwargs: Accepted for signature compatibility; not read.

    Returns:
        ``(grad_query, grad_key, grad_value)``. ``grad_key`` is the node
        produced by ``autotune_select_algorithm``; ``grad_query`` and
        ``grad_value`` are buffers allocated here and handed to the template
        as mutated inputs.

    Raises:
        NotImplementedError: If ``_use_flex_decoding(query)`` is true.
    """
    (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        score_mod_other_buffers,
        mask_fn_other_buffers,
    ) = args
    (
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_KV_BLOCK_SIZE,
        SPARSE_Q_BLOCK_SIZE,
        mask_graph,
    ) = block_mask

    # Materialize every tensor input the kernel reads directly. The "full"
    # block-mask entries may be None and are skipped in that case.
    must_realize = (
        query,
        key,
        value,
        grad_out,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
    )
    for node in must_realize:
        if node is None:
            continue
        node.realize()

    if _use_flex_decoding(query):
        raise NotImplementedError("Flex decoding backward pass is not implemented. ")

    device = query.get_device()
    dtype = query.get_dtype()

    # Scalar placeholders fed to the inlined subgraphs: the score itself (in the
    # query dtype) followed by the four int32 index arguments.
    index_arg_names = ("b", "h", "m", "n")

    fwd_placeholder_inps = [create_placeholder("score", dtype, device)]
    fwd_placeholder_inps.extend(
        create_placeholder(arg_name, torch.int32, device)
        for arg_name in index_arg_names
    )
    fw_subgraph_buffer = build_subgraph_buffer(
        [*fwd_placeholder_inps, *score_mod_other_buffers], fw_graph
    )

    # The joint graph takes the forward inputs plus the incoming score gradient;
    # only its first output is inlined into the template.
    joint_placeholder_inps = [
        *fwd_placeholder_inps,
        create_placeholder("grad_score_mod", dtype, device),
    ]
    joint_subgraph_buffer, *_ = build_subgraph_buffer(
        [*joint_placeholder_inps, *score_mod_other_buffers], joint_graph
    )

    # The mask graph only sees the index arguments.
    mask_graph_placeholder_inps = [
        create_placeholder(arg_name, torch.int32, device)
        for arg_name in index_arg_names
    ]
    mask_graph_buffer = build_subgraph_buffer(
        [*mask_graph_placeholder_inps, *mask_fn_other_buffers], mask_graph
    )

    # The template's own output is grad_key, so it mirrors key's layout.
    layout_k = FixedLayout(
        key.get_device(),
        key.get_dtype(),
        key.get_size(),
        key.get_stride(),
    )

    # delta = sum(out * grad_out) over the last axis; the backward kernel needs it.
    delta = lowerings[aten.sum](lowerings[aten.mul](out, grad_out), axis=-1)

    # see NOTE:[TritonTemplates with multiple outputs]
    grad_query = empty_strided(
        query.get_size(), query.get_stride(), dtype=dtype, device=device
    )
    grad_value = empty_strided(
        value.get_size(), value.get_stride(), dtype=dtype, device=device
    )

    # The kernel only runs its dedicated full-block pass when full-block
    # metadata was supplied. When it is absent (None), empty stand-in tensors
    # are passed so the kernel signature stays the same.
    has_full_blocks = full_kv_num_blocks is not None
    if not has_full_blocks:
        (
            full_kv_num_blocks,
            full_kv_indices,
            full_q_num_blocks,
            full_q_indices,
        ) = [empty(0, device=device) for _ in range(4)]

    # Positional inputs of the template; the autotuning input list below
    # extends this with the captured buffers of the subgraphs.
    template_inputs = [
        query,
        key,
        value,
        logsumexp,
        delta,
        grad_out,
        grad_query,
        grad_value,
        kv_num_blocks,
        kv_indices,
        q_num_blocks,
        q_indices,
        full_kv_num_blocks,
        full_kv_indices,
        full_q_num_blocks,
        full_q_indices,
    ]

    # Candidate (BLOCK1, BLOCK2, num_warps, num_stages) tuples: the default
    # first, then the max-autotune sweep restricted to BLOCK2 % BLOCK1 == 0.
    configs: List[Tuple[int, int, int, int]] = [_get_default_config_bwd(query)]
    if config.max_autotune:
        configs += [
            (BLOCK1, BLOCK2, w, s)
            for BLOCK1 in [32, 64]
            for BLOCK2 in [32, 64, 128]
            if BLOCK2 % BLOCK1 == 0
            for w in [4, 8]
            for s in [1, 3, 4, 5]
        ]

    choices: List[Any] = []
    for BLOCK1, BLOCK2, num_warps, num_stages in configs:
        # Both tile sizes must evenly divide both sparse block sizes.
        if any(
            sparse_size % tile != 0
            for tile in (BLOCK1, BLOCK2)
            for sparse_size in (SPARSE_KV_BLOCK_SIZE, SPARSE_Q_BLOCK_SIZE)
        ):
            continue

        flex_attention_backward_template.maybe_append_choice(
            choices=choices,
            input_nodes=list(template_inputs),
            layout=layout_k,  # We use store_output only for grad_key
            subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer],
            mutated_inputs=[grad_query, grad_value],
            call_sizes=query.get_size() + [key.get_size()[2]],
            num_stages=num_stages,
            num_warps=num_warps,
            SM_SCALE=scale,
            BLOCK_DMODEL=query.get_size()[-1],
            # Performance tuning
            BLOCK_M1=BLOCK1,
            BLOCK_N1=BLOCK2,
            BLOCK_M2=BLOCK2,
            BLOCK_N2=BLOCK1,
            # Blocksparse options
            SPARSE_Q_BLOCK_SIZE=SPARSE_Q_BLOCK_SIZE,
            SPARSE_KV_BLOCK_SIZE=SPARSE_KV_BLOCK_SIZE,
            # For now, we always assume the "sound" option
            PRESCALE_QK=False,
            HAS_FULL_BLOCKS=has_full_blocks,
        )

    inputs_for_autotuning = [
        *template_inputs,
        *score_mod_other_buffers,
        *mask_fn_other_buffers,
    ]

    # Block-mask inputs occupy positions 8..15 as (num_blocks, indices) pairs:
    # kv, q, full_kv, full_q. Each gets a dedicated fake-input generator for
    # benchmarking; the num_blocks generator is derived from its indices node.
    input_gen_fns = {}
    indices_nodes = (kv_indices, q_indices, full_kv_indices, full_q_indices)
    for num_blocks_pos, indices_node in zip(range(8, 16, 2), indices_nodes):
        input_gen_fns[num_blocks_pos] = create_num_blocks_fake_generator(indices_node)
        input_gen_fns[num_blocks_pos + 1] = create_indices_fake

    grad_key = autotune_select_algorithm(
        "flex_attention_backward",
        choices,
        inputs_for_autotuning,
        layout_k,
        input_gen_fns=input_gen_fns,
    )
    return (grad_query, grad_key, grad_value)
